{% extends 'base.attached.class' %}

{% block content %}
class {{ class_name }} {

    {% if is_binary %}
    /** Weights of the binary model; read in parallel with the feature vector. */
    private double[] coeffs;
    /** Scalar added to the weighted sum before the sign test. */
    private double inters;

    /**
     * Keeps references to the given model parameters (no copies are made).
     *
     * @param coeffs one weight per feature
     * @param inters scalar added to the weighted sum
     */
    public {{ class_name }}(double[] coeffs, double inters) {
        this.inters = inters;
        this.coeffs = coeffs;
    }

    /**
     * Classifies one sample with the binary model.
     *
     * @param features input values; read at every index of {@code coeffs}
     * @return 1 when the weighted sum plus {@code inters} is strictly
     *         greater than zero, otherwise 0
     */
    public int predict(double[] features) {
        double weightedSum = 0.;
        for (int k = 0; k < this.coeffs.length; k++) {
            weightedSum += this.coeffs[k] * features[k];
        }
        return (weightedSum + this.inters > 0) ? 1 : 0;
    }
    {% else %}
    /** Weights of the multi-class model, one row per class. */
    private double[][] coeffs;
    /** Per-class scalars added to each row's weighted sum. */
    private double[] inters;

    /**
     * Keeps references to the given model parameters (no copies are made).
     *
     * @param coeffs one row of weights per class
     * @param inters one scalar per class
     */
    public {{ class_name }}(double[][] coeffs, double[] inters) {
        this.inters = inters;
        this.coeffs = coeffs;
    }

    /**
     * Classifies one sample with the multi-class model.
     *
     * @param features input values; read at every column index of the
     *                 first row of {@code coeffs}
     * @return index of the class with the largest score; the comparison is
     *         strict, so the lowest index wins on ties
     */
    public int predict(double[] features) {
        int bestIdx = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int row = 0; row < this.inters.length; row++) {
            double dot = 0.;
            // The column count is always taken from the first row.
            final int width = this.coeffs[0].length;
            for (int col = 0; col < width; col++) {
                dot += this.coeffs[row][col] * features[col];
            }
            final double score = dot + this.inters[row];
            if (score > bestScore) {
                bestScore = score;
                bestIdx = row;
            }
        }
        return bestIdx;
    }
    {% endif %}

    /**
     * Command-line entry point: parses every argument as a double, builds
     * the estimator from the model data rendered into this method and
     * prints the predicted class index.
     *
     * @param args feature values, one per argument
     * @throws IllegalArgumentException if the argument count differs from
     *                                  the expected number of features
     */
    public static void main(String[] args) {
        int nFeatures = {{ n_features }};
        if (args.length != nFeatures) {
            throw new IllegalArgumentException("You have to pass " + nFeatures + " features.");
        }

        // Features:
        double[] features = new double[args.length];
        for (int idx = 0; idx < features.length; idx++) {
            features[idx] = Double.parseDouble(args[idx]);
        }

        // Model data (must declare the locals `coeffs` and `inters` used below):
        {{ coeffs }}
        {{ inters }}

        // Estimator:
        {{ class_name }} clf = new {{ class_name }}(coeffs, inters);
        int prediction = clf.predict(features);

        {% if is_test or to_json %}
        // Get JSON:
        System.out.println("{\"predict\": " + prediction + "}");
        {% else %}
        // Get class prediction:
        System.out.println("Predicted class: #" + prediction);
        {% endif %}

    }
}
{% endblock %}